"""Semi-supervised learning for EM for GMM."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

from utils import io_tools
from models.gaussian_mixture_model import GaussianMixtureModel

# Command-line flags read by main() through FLAGS.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer(
    name='max_iter',
    default=15,
    help='Number of EM steps to run.',
)
flags.DEFINE_integer(
    name='n_components',
    default=10,
    help='Number of components',
)

def main(_):
    """Run the semi-supervised Gaussian mixture model pipeline.

    The stages are:

    1. Fit a ``GaussianMixtureModel`` on the training features only
       (``model.fit``).
    2. Call ``model.supervised_fit`` with the training features and labels.
    3. Predict on the test set with ``model.supervised_predict`` and print
       the fraction of predictions equal to the test labels, followed by
       ``model.cluster_label_map`` and the sum of ``model._pi``.

    Args:
        _: Unused. Accepts the positional argument handed over by the
            application runner that launches this function.
    """
    train_path = 'data/mnist_train.csv'
    eval_path = 'data/mnist_test.csv'

    # Parse the training CSV once: the same file feeds both the unsupervised
    # stage (features only) and the supervised stage (features + labels).
    # NOTE(review): assumes model.fit() does not modify its input in place --
    # confirm against GaussianMixtureModel.
    train_label, train_data = io_tools.read_dataset(train_path)
    n_dims = train_data.shape[1]

    # Initialize model.
    model = GaussianMixtureModel(n_dims, n_components=FLAGS.n_components,
                                 max_iter=FLAGS.max_iter, reg_covar=10)

    # Unsupervised training.
    print("Unsupervised training...")
    model.fit(train_data)

    # Supervised training.
    print("Supervised training...")
    model.supervised_fit(train_data, train_label)

    # Eval model.
    print("Eval model...")
    eval_label, eval_data = io_tools.read_dataset(eval_path)
    y_hat_eval = model.supervised_predict(eval_data)

    # True division: numpy integer / int yields a float (and the module
    # imports ``division`` from ``__future__``), so no float cast is needed.
    n_correct = np.sum(y_hat_eval == eval_label)
    acc = n_correct / eval_data.shape[0]

    print("------ Results -------")
    print("Accuracy: %s" % acc)

    print("Check internal fields:")
    print("labels: {}".format(model.cluster_label_map))
    # Sanity check printed for manual inspection; per the message below the
    # sum of model._pi is expected to be 1.
    print("check sum pi tp 1:")
    print(np.sum(model._pi))


if __name__ == '__main__':
    # Script entry point: hand control to the TensorFlow app runner, naming
    # the pipeline function explicitly instead of relying on the default
    # lookup of ``main`` in the ``__main__`` module.
    tf.app.run(main=main)
